# encoding:utf-8
import psycopg2
import pandas as pd
import argparse
import io
import pymysql
import json
import numpy as np
from config.config import *

def save_gp(file_path, table_name):
	"""Load a CSV file into a (re)created Greenplum/Postgres table.

	Reads *file_path* with pandas, drops pandas' auto-saved "Unnamed: 0"
	index columns, adds a 1-based ``_record_id_`` column when missing,
	recreates *table_name* (dropping any existing table) and bulk-loads
	the rows via COPY.

	:param file_path: path of the CSV file to load
	:param table_name: destination table name. NOTE(review): interpolated
		verbatim into DDL/COPY statements — must come from a trusted source.
	:return: tuple (column names, mapped GP column types, row count)
	"""
	df = pd.read_csv(file_path)
	row_num = df.shape[0]

	# Drop pandas' auto-saved index columns before deriving the schema.
	for col in [c for c in df.columns if "Unnamed: 0" in c]:
		del df[col]

	# Guarantee a stable per-row identifier column.
	if "_record_id_" not in df.columns:
		df['_record_id_'] = np.arange(1, row_num + 1)

	# Compute the schema BEFORE opening the connection so the return
	# values are always bound, even if the DB work below fails.
	df_columns = list(df.columns)
	dtypes = []
	for dtype in df.dtypes:
		if dtype == float:
			dtypes.append("float4")
		elif dtype == int:
			dtypes.append("integer")
		else:
			# Anything non-numeric falls back to varchar.
			dtypes.append("varchar")

	conn = psycopg2.connect(dbname=GP_DBNAME,
	                        user=GP_USER,
	                        password=GP_PASSWORD,
	                        host=GP_HOST,
	                        port=GP_PORT)
	cursor = conn.cursor()
	try:
		cursor.execute("drop table if exists {}".format(table_name))

		columns = ["\"" + column + "\"" + " " + dtype
		           for column, dtype in zip(df_columns, dtypes)]
		cursor.execute("create table {}({})".format(table_name, ",".join(columns)))

		# Stream the frame through an in-memory buffer into COPY.
		data_io = io.StringIO()
		df.to_csv(data_io, sep="|", index=False)
		data_io.seek(0)
		# data_io.readline()  # remove header DO NOT DELETE THIS COMMENT
		copy_cmd = "COPY %s FROM STDIN HEADER DELIMITER '|' CSV" % table_name
		cursor.copy_expert(copy_cmd, data_io)
		conn.commit()
	except Exception as e:
		# Undo the partial DDL/load instead of committing a broken table;
		# stay best-effort (log, don't crash) like the surrounding code.
		conn.rollback()
		print(e)
	finally:
		cursor.close()
		conn.close()
	return df_columns, dtypes, row_num


def get_mysql_cursor():
	"""Open and return a new pymysql connection to the configured MySQL DB.

	NOTE(review): despite the name, this returns a *connection*, not a
	cursor; the name is kept for backward compatibility with callers.
	"""
	return pymysql.connect(
		host=MYSQL_HOST,
		port=int(MYSQL_PORT),
		user=MYSQL_USER,
		password=MYSQL_PASSWORD,
		db=MYSQL_DBNAME,
	)


def read_mysql(sql):
	"""Run *sql* against MySQL and return the result set as a DataFrame.

	Column names are taken from the cursor description.

	:param sql: a complete SELECT statement (no parameter binding here)
	:return: pandas.DataFrame with one row per fetched record
	"""
	conn = get_mysql_cursor()
	try:
		cursor = conn.cursor()
		try:
			cursor.execute(sql)
			rows = cursor.fetchall()
			columns = [desc[0] for desc in cursor.description]
		finally:
			cursor.close()
	finally:
		# Always release the connection, even when the query fails.
		conn.close()
	return pd.DataFrame(rows, columns=columns)


def get_task_instance_data_json(task_instance_id):
	"""Fetch the decoded ``data_json`` payload of one task_instance row.

	:param task_instance_id: numeric primary key of the task_instance row;
		coerced to int before interpolation
	:return: the ``data_json`` dict of the last matching row
	:raises ValueError: if *task_instance_id* is not convertible to int
	"""
	# int() both validates the id and neutralizes SQL injection, since
	# read_mysql() offers no parameter binding.
	sql = "select * from task_instance where id = '{}'".format(int(task_instance_id))
	series = read_mysql(sql).iloc[-1]
	return json.loads(series["data_json"])


def execute_mysql(sql):
	"""Execute a write statement against MySQL and commit it.

	The cursor and connection are always closed, even if execution fails.

	:param sql: a complete INSERT/UPDATE/DELETE statement
	"""
	conn = get_mysql_cursor()
	try:
		cursor = conn.cursor()
		try:
			cursor.execute(sql)
			conn.commit()
		finally:
			cursor.close()
	finally:
		# Release the connection even when execute()/commit() raises.
		conn.close()


def py_to_java(s):
	"""Rewrite a Python-literal string into JSON-flavoured text.

	Applies plain substring substitutions, in order: single quotes become
	double quotes, Python booleans/None/nan become their JSON counterparts
	and the ", " / ": " separators are compacted.

	NOTE(review): naive substring replacement — these tokens are rewritten
	even when they occur inside values.

	:param s: stringified Python literal (e.g. ``str(some_dict)``)
	:return: the rewritten string
	"""
	replacements = (
		("'", '"'),
		("False", "false"),
		("True", "true"),
		(", ", ","),
		(": ", ":"),
		("None", "null"),
		("nan", "null"),
	)
	for old, new in replacements:
		s = s.replace(old, new)
	return s


def update_task_instance(data_json, task_instance_id):
	"""Persist *data_json* back onto its task_instance row.

	The dict is stringified, converted to JSON-style text with
	py_to_java() and written with a plain UPDATE statement.
	"""
	payload = "'" + py_to_java(str(data_json)) + "'"
	sql = "update task_instance set data_json={} where id={} ".format(
		payload, task_instance_id)
	execute_mysql(sql)


def update_data_json(task_instance_id, df_columns, dtypes, table_name, row_num):
	"""Record a freshly-loaded table's metadata on its task instance.

	Builds an output descriptor (columns, types, table name, per-column
	semantics placeholders and row count) and stores it as the single
	element of ``data_json["inputInfo"]["output"]``.
	"""
	data_json = get_task_instance_data_json(task_instance_id)
	output_info = {
		"tableCols": df_columns,
		"columnTypes": dtypes,
		"tableName": table_name,
		# Semantics are unknown at load time; placeholder per column.
		"semantic": {col: "null" for col in df_columns},
		"totalRow": row_num,
	}
	data_json["inputInfo"]["output"] = [output_info]
	update_task_instance(data_json, task_instance_id)


if __name__ == '__main__':
	# CLI entry point: load a CSV into GP, then write the resulting
	# table metadata back onto the owning task instance.
	parser = argparse.ArgumentParser()
	parser.add_argument("--file_path", type=str, required=True)
	parser.add_argument("--table_name", type=str, required=True)
	parser.add_argument("--task_instance_id", type=int, required=True)
	args = parser.parse_args()

	cols, col_types, total = save_gp(args.file_path, args.table_name)
	update_data_json(args.task_instance_id, cols, col_types,
	                 args.table_name, total)
